package com.idea.relax.boot.file;

import org.springframework.core.MethodParameter;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.List;

/**
 * @author: 沉香
 * @date: 2023/4/21
 * @description: Spring MVC argument resolver that binds multipart uploads to handler-method
 * parameters declared exactly as {@link FileObject}, {@code FileObject[]} or {@link FileObjects}.
 * Uploaded files are looked up in the multipart request under the handler parameter's name.
 */
public class FileArgumentResolver implements HandlerMethodArgumentResolver {

	/**
	 * Reports whether this resolver handles the given handler-method parameter.
	 *
	 * @param parameter the handler-method parameter to inspect
	 * @return {@code true} if the declared type is exactly {@code FileObject},
	 * {@code FileObject[]} or {@code FileObjects} (subtypes are not matched)
	 */
	@Override
	public boolean supportsParameter(@Nonnull MethodParameter parameter) {
		return isFileObjectType(parameter) || isFileObjectArrayType(parameter) || isFileObjectsType(parameter);
	}

	/**
	 * Builds the argument value from the files uploaded under the parameter's name.
	 *
	 * <p>For a {@code FileObject[]} parameter every uploaded file is wrapped, in request order.
	 * For a {@code FileObjects} parameter a list of wrapped files is returned. For a plain
	 * {@code FileObject} parameter only the first uploaded file is used and any further files
	 * with the same name are ignored.
	 *
	 * @param parameter     the handler-method parameter being resolved
	 * @param mavContainer  unused
	 * @param webRequest    the current request; its native request must be a {@link MultipartHttpServletRequest}
	 * @param binderFactory unused
	 * @return the resolved value, or {@code null} when no file was uploaded under the parameter's name
	 * @throws IllegalArgumentException if the parameter name is not available via reflection,
	 *                                  or if the native request is not a multipart request
	 */
	@Override
	public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer,
								  NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception {
		String parameterName = resolveParameterName(parameter);
		MultipartHttpServletRequest multipartRequest = webRequest.getNativeRequest(MultipartHttpServletRequest.class);
		Assert.notNull(multipartRequest, "获取MultipartHttpServletRequest对象失败！");
		MultiValueMap<String, MultipartFile> multiFileMap = multipartRequest.getMultiFileMap();
		List<MultipartFile> multipartFiles = multiFileMap.get(parameterName);
		if (null == multipartFiles || multipartFiles.isEmpty()) {
			return null;
		}
		if (isFileObjectArrayType(parameter)) {
			return toFileObjects(multipartFiles).toArray(new FileObject[0]);
		}
		if (isFileObjectsType(parameter)) {
			// NOTE(review): the declared parameter type is FileObjects, which is a project type
			// (it is not imported, so it lives in this package). A java.util.List can therefore
			// never be an instance of it, and Spring's reflective handler invocation will fail
			// with "argument type mismatch". Build a real FileObjects instance here.
			// TODO confirm FileObjects' constructor/API.
			return toFileObjects(multipartFiles);
		}
		// Plain FileObject parameter: only the first uploaded file is bound.
		return new FileObject(multipartFiles.get(0));
	}

	/**
	 * Returns the handler parameter's name, failing fast when it is unavailable.
	 *
	 * <p>Without the name, the multipart part cannot be looked up. Before this check the
	 * lookup silently produced {@code null}, even when files were present.
	 *
	 * @param parameter the handler-method parameter
	 * @return the non-null parameter name
	 * @throws IllegalArgumentException if the name is not available via reflection
	 */
	private static String resolveParameterName(MethodParameter parameter) {
		String parameterName = parameter.getParameterName();
		if (parameterName == null) {
			throw new IllegalArgumentException("Name for argument of type [" + parameter.getParameterType().getName()
					+ "] not available via reflection. Ensure that the compiler uses the '-parameters' flag.");
		}
		return parameterName;
	}

	/**
	 * Wraps each uploaded file in a {@link FileObject}, preserving upload order.
	 *
	 * @param multipartFiles the uploaded files (non-null)
	 * @return a new mutable list with one {@code FileObject} per uploaded file
	 */
	private static List<FileObject> toFileObjects(List<MultipartFile> multipartFiles) {
		List<FileObject> fileObjects = new ArrayList<>(multipartFiles.size());
		for (MultipartFile multipartFile : multipartFiles) {
			fileObjects.add(new FileObject(multipartFile));
		}
		return fileObjects;
	}

	/** Returns {@code true} if the parameter is declared exactly as {@code FileObject}. */
	private boolean isFileObjectType(MethodParameter parameter) {
		return parameter.getParameterType().equals(FileObject.class);
	}

	/** Returns {@code true} if the parameter is declared exactly as {@code FileObject[]}. */
	private boolean isFileObjectArrayType(MethodParameter parameter) {
		return parameter.getParameterType().equals(FileObject[].class);
	}

	/** Returns {@code true} if the parameter is declared exactly as {@code FileObjects}. */
	private boolean isFileObjectsType(MethodParameter parameter) {
		return parameter.getParameterType().equals(FileObjects.class);
	}

}
